{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the PyG stack. `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "# The torch-scatter / torch-sparse wheel index below is the one for torch 1.10.0 + CUDA 11.3.\n",
    "# NOTE(review): pytorch_geometric is installed unpinned from the repository's default branch -\n",
    "# consider pinning a release that matches the wheels above.\n",
    "%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "%pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import Node2Vec\n",
    "import os.path as osp\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from tqdm.notebook import tqdm\n",
    "from torch_geometric.datasets import TUDataset\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node2Vec for link prediction\n",
    "\n",
    "In this tutorial, we use the node embeddings produced by Node2Vec, then we compute the edge embedding $(emb(E))$ of each edge $E = (u,v)$ as follows:\n",
    "\n",
    "$$\n",
    "emb(E) = emb(u,v) = \\frac{1}{2}(emb(u) + emb(v))\n",
    "$$\n",
    "\n",
    "Given the edge embedding, we predict the class label of the edge using a RandomForestClassifier.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 1\n",
    "Build a dataset different from Cora :)  \n",
    "  \n",
    "We use AIDS[1][2], a dataset representing 2000 molecular compounds. Each molecule is represented as a graph, and each graph has an attribute indicating whether the compound is active or inactive against HIV. \n",
    "\n",
    "\n",
    "<sub>[1] Riesen, K. and Bunke, H.: IAM Graph Database Repository for Graph Based Pattern Recognition and Machine Learning. In: da Vitoria Lobo, N. et al. (Eds.), SSPR&SPR 2008, LNCS, vol. 5342, pp. 287-297, 2008.  \n",
    "[2] AIDS Antiviral Screen Data (2004)</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The data-preparation cells run on the CPU. The previous CUDA auto-detection line was removed\n",
    "# because its result was overwritten by this assignment before it was ever used.\n",
    "device = \"cpu\"\n",
    "\n",
    "# Load the AIDS graphs from the TU collection, using the current directory as the dataset root.\n",
    "dataset = \"AIDS\"\n",
    "data = TUDataset(\".\", name=dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary of the loaded dataset object and its class / feature / edge-label counts.\n",
    "print(data)\n",
    "print(\"number of classes: \",data.num_classes,\"\\t\\t(active),(inactive)\")\n",
    "print(\"number of features: \",data.num_features)\n",
    "print(\"number of edge labels: \",data.num_edge_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the graph at index 1 of the dataset as a small example to inspect.\n",
    "data1 = data[1]\n",
    "\n",
    "# extract edge list: transpose edge_index so that each row holds the two endpoints of one edge\n",
    "edge_list = data1.edge_index.t().numpy()\n",
    "print(edge_list[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract edge attributes of the example graph as a NumPy array (one row per edge)\n",
    "edge_attr = data1.edge_attr.numpy()\n",
    "print(edge_attr[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# Assemble an undirected NetworkX graph from the endpoint pairs in `edge_list`,\n",
    "# storing the matching row of `edge_attr` on each edge under the key 'label'.\n",
    "graph1 = nx.Graph()\n",
    "for (u, v), attr in zip(edge_list, edge_attr):\n",
    "    graph1.add_edge(u, v, label=attr)\n",
    "\n",
    "print(graph1.edges(data=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed: the spring layout starts from random positions, so without a seed\n",
    "# the drawing would look different on every run of the notebook.\n",
    "pos = nx.spring_layout(graph1, seed=10)\n",
    "nx.draw(graph1, pos)\n",
    "# Annotate each edge with the 'label' attribute stored when the graph was built.\n",
    "nx.draw_networkx_edge_labels(graph1, pos, edge_labels=nx.get_edge_attributes(graph1, 'label'))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use only one molecular compound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the dataset's underlying `data` attribute.\n",
    "# NOTE(review): presumably the collated storage for all graphs; newer PyG versions may warn on this access - confirm.\n",
    "data.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the entry at index 10 of the dataset.\n",
    "# NOTE: this rebinds `data` from the whole dataset to a single entry, so the cell is not\n",
    "# idempotent - re-running it without reloading the dataset indexes into that entry instead.\n",
    "data = data[10]\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### build the graph with\n",
    "train_mask, test_mask, val_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(10)\n",
    "# Every node id that occurs as an endpoint of at least one edge (sorted, duplicates removed).\n",
    "nodes = np.unique(data.edge_index.numpy())\n",
    "\n",
    "# Randomise the order in place; the seed above keeps the permutation repeatable.\n",
    "np.random.shuffle(nodes)\n",
    "print(len(nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes for a 70% / 15% / 15% train / test / validation partition.\n",
    "# The test size is the gap between the 85% and 70% cut points; validation takes whatever is left.\n",
    "num_split_nodes = len(nodes)\n",
    "train_size = int(num_split_nodes * 0.7)\n",
    "test_size = int(num_split_nodes * 0.85) - train_size\n",
    "val_size = num_split_nodes - (train_size + test_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut the shuffled node array into consecutive train / test / validation segments.\n",
    "test_end = train_size + test_size\n",
    "train_set, test_set, val_set = nodes[:train_size], nodes[train_size:test_end], nodes[test_end:]\n",
    "\n",
    "# Sanity check: the three segments together must account for every node.\n",
    "split_lengths = [len(train_set), len(test_set), len(val_set)]\n",
    "print(*split_lengths)\n",
    "print(sum(split_lengths) == len(nodes))\n",
    "\n",
    "# Show the first ids of each segment.\n",
    "for caption, subset in ((\"train set\\t\", train_set), (\"test set \\t\", test_set), (\"val set  \\t\", val_set)):\n",
    "    print(caption, subset[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build boolean train / test / validation node masks.\n",
    "#\n",
    "# The masks are boolean (not 0/1 integers) so that `z[mask]` selects the marked rows;\n",
    "# indexing with a 0/1 integer tensor would instead gather rows 0 and 1 over and over.\n",
    "# They are sized by the largest node id + 1 (not by the number of distinct ids) so that\n",
    "# every id in `nodes` is a valid position even if some smaller id never appears in an edge.\n",
    "# NOTE(review): presumably this is also the number of rows of the Node2Vec embedding - confirm.\n",
    "num_mask_entries = int(nodes.max()) + 1 if len(nodes) else 0\n",
    "\n",
    "\n",
    "def build_mask(node_ids):\n",
    "    \"\"\"Return a bool tensor on `device` that is True exactly at the positions listed in `node_ids`.\"\"\"\n",
    "    mask = torch.zeros(num_mask_entries, dtype=torch.bool, device=device)\n",
    "    mask[torch.as_tensor(node_ids, dtype=torch.long, device=device)] = True\n",
    "    return mask\n",
    "\n",
    "\n",
    "train_mask = build_mask(train_set)\n",
    "test_mask = build_mask(test_set)\n",
    "val_mask = build_mask(val_set)\n",
    "\n",
    "print(\"train mask \\t\",train_mask[0:15])\n",
    "print(\"test mask  \\t\",test_mask[0:15])\n",
    "print(\"val mask   \\t\",val_mask[0:15]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove from the graph the attributes that are not used in the following cells.\n",
    "# NOTE: after this cell `data.x`, `data.edge_attr` and `data.y` are all None on this object.\n",
    "\n",
    "print(\"befor\\t\\t\",data)\n",
    "data.x = None\n",
    "data.edge_attr = None\n",
    "data.y = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add masks: attach the train / test / validation node masks to the graph object\n",
    "data.train_mask = train_mask\n",
    "data.test_mask = test_mask\n",
    "data.val_mask = val_mask\n",
    "\n",
    "print(\"after\\t\\t\",data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 2\n",
    "Execute Node2Vec to get node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Node2Vec hyper-parameters, gathered in one place so they are easy to tune.\n",
    "node2vec_params = dict(\n",
    "    embedding_dim=128,\n",
    "    walk_length=20,\n",
    "    context_size=10,\n",
    "    walks_per_node=10,\n",
    "    num_negative_samples=1,\n",
    "    p=1,\n",
    "    q=1,\n",
    "    sparse=True,  # paired with the SparseAdam optimizer below\n",
    ")\n",
    "model = Node2Vec(data.edge_index, **node2vec_params).to(device)\n",
    "\n",
    "# Batches of positive / negative random walks, plus the optimizer over the model parameters.\n",
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=4)\n",
    "optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one epoch over `loader` and return the mean loss per batch.\"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for pos_rw, neg_rw in loader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    return total_loss / len(loader)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test():\n",
    "    \"\"\"Score the embeddings with `model.test` on the train / test node masks.\n",
    "\n",
    "    Needs per-node labels in `data.y`. In this notebook `data.y` is set to None\n",
    "    before training, so this helper is not called from the training loop below.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `model.test` returns for the masked train / test rows.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if `data.y` is None.\n",
    "    \"\"\"\n",
    "    if data.y is None:\n",
    "        raise RuntimeError(\"test() needs node labels in data.y, but data.y is None\")\n",
    "    model.eval()\n",
    "    z = model()\n",
    "    # Boolean masks select the marked rows; a 0/1 integer mask would gather rows 0 and 1 instead.\n",
    "    train_mask = data.train_mask.bool()\n",
    "    test_mask = data.test_mask.bool()\n",
    "    acc = model.test(z[train_mask], data.y[train_mask],\n",
    "                     z[test_mask], data.y[test_mask],\n",
    "                     max_iter=10)\n",
    "    return acc\n",
    "\n",
    "\n",
    "# Train for 100 epochs, reporting the mean loss every 10th epoch.\n",
    "for epoch in range(1, 101):\n",
    "    loss = train()\n",
    "    if epoch % 10 == 0:\n",
    "        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calling the trained model with no arguments yields the embedding tensor used below.\n",
    "# NOTE(review): presumably one row per node, in node-id order - confirm against the PyG Node2Vec docs.\n",
    "z = model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### visualize node embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tensor to numpy: drop the autograd graph, move to host memory, then convert\n",
    "emb_128 = z.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Project the node embeddings onto their first two principal components.\n",
    "pca = PCA(n_components=2)\n",
    "emb2d = pca.fit_transform(emb_128)\n",
    "\n",
    "# One point per node: first component on x, second on y.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"node embedding in 2D\")\n",
    "ax.scatter(*emb2d.T)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3\n",
    "Compute edge embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the AIDS dataset and pick the same graph (index 10) again:\n",
    "# the copy used for training had its edge_attr set to None, and the edge attributes are needed below.\n",
    "\n",
    "\n",
    "dataset = \"AIDS\"\n",
    "data = TUDataset(\".\", name=dataset)\n",
    "data = data[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn each edge-attribute row into an integer class id: the position of its first non-zero entry.\n",
    "# NOTE(review): assumes the rows are one-hot encoded (exactly one non-zero entry) - confirm.\n",
    "edge_attr_cat = data.edge_attr.numpy()\n",
    "print(\"Categorical edge attributes:\\n\",edge_attr_cat[:3])\n",
    "\n",
    "edge_attr = [np.nonzero(attr_row)[0][0] for attr_row in edge_attr_cat]\n",
    "\n",
    "print(\"\\n\\nNumerical edge attributes:\\n\",edge_attr[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge embedding = element-wise mean of the two endpoint embeddings: emb(u, v) = (emb(u) + emb(v)) / 2.\n",
    "# All edges are handled at once with NumPy index arrays instead of a Python loop over tensor scalars.\n",
    "src_nodes, dst_nodes = data.edge_index.numpy()\n",
    "\n",
    "# Kept as a list with one vector per edge, in the same order as the columns of edge_index.\n",
    "edge_embedding = list((emb_128[src_nodes] + emb_128[dst_nodes]) / 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### visualize edge embedding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# fit and transform using PCA\n",
    "pca = PCA(n_components=2)\n",
    "edge_emb2d = pca.fit_transform(edge_embedding)\n",
    "\n",
    "# One colour per edge class. A class without an entry would map to NaN and break the\n",
    "# scatter call, so a third colour is provided and any unmapped class is reported explicitly.\n",
    "# NOTE(review): the dataset may define more than two edge-label classes - see the\n",
    "# `num_edge_labels` value printed in step 1.\n",
    "df = pd.DataFrame(dict(edge_att=edge_attr))\n",
    "colors = {0:\"red\",1:\"blue\",2:\"green\"}\n",
    "unmapped = set(df.edge_att) - set(colors)\n",
    "assert not unmapped, f\"no colour defined for edge classes {sorted(unmapped)}\"\n",
    "\n",
    "plt.title(\"edge embedding in 2D\")\n",
    "plt.scatter(edge_emb2d[:,0],edge_emb2d[:,1],c=df.edge_att.map(colors))\n",
    "plt.show()\n",
    "\n",
    "# not so good but we are using PCA to reduce the dim from 128 to 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4\n",
    "Use RandomForestClassifier with 10-fold cross validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random forest on the edge embeddings (features) and edge class ids (targets),\n",
    "# evaluated with 10-fold cross-validation; the fixed random_state keeps the forest repeatable.\n",
    "clf = RandomForestClassifier(max_depth=7, random_state=10)\n",
    "scores = cross_val_score(clf, edge_embedding, edge_attr, cv=10)\n",
    "\n",
    "# Mean of the per-fold scores, shown as the cell output.\n",
    "scores.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
